# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import gym
from gym.spaces.box import Box
import omegaconf
import torch
import torch.nn as nn
from torch.nn.modules.linear import Identity
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
from pathlib import Path
import pickle
from torchvision.utils import save_image
import hydra
import clip
import torchvision
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


def init(module, weight_init, bias_init, gain=1):
    """Initialise a layer's weight (and bias, if it has one) in place.

    Args:
        module: layer exposing ``weight`` and, optionally, ``bias``; each is
            expected to carry its tensor under ``.data``.
        weight_init: callable invoked as ``weight_init(weight.data, gain=gain)``.
        bias_init: callable invoked as ``bias_init(bias.data)``.
        gain: forwarded to ``weight_init`` as the ``gain`` keyword.

    Returns:
        The same ``module`` object, so the call can be used inline.
    """
    weight_init(module.weight.data, gain=gain)
    # Layers created with bias=False expose ``bias`` as None; there is nothing
    # to initialise in that case, so skip instead of failing on ``None.data``.
    bias = getattr(module, "bias", None)
    if bias is not None:
        bias_init(bias.data)
    return module

def _get_embedding(embedding_name='resnet34', load_path="", *args, **kwargs):
    if load_path == "random":
        prt = False
    else:
        prt = True
    if embedding_name == 'resnet34':
        model = models.resnet34(pretrained=prt, progress=False)
        embedding_dim = 512
    elif embedding_name == 'resnet18':
        model = models.resnet18(pretrained=prt, progress=False)
        embedding_dim = 512
    elif embedding_name == 'resnet50':
        model = models.resnet50(pretrained=prt, progress=False)
        embedding_dim = 2048
    else:
        print("Requested model not available currently")
        raise NotImplementedError
    # make FC layers to be identity
    # NOTE: This works for ResNet backbones but should check if same
    # template applies to other backbone architectures
    model.fc = Identity()
    model = model.eval()
    return model, embedding_dim


class ClipEnc(nn.Module):
    """Adapter exposing a CLIP model through an ``(image, instruction)`` call.

    Args:
        m: model providing ``encode_image`` and ``encode_text`` (in this file
            it is the model returned by ``clip.load``).
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, im, lang):
        """Encode an image batch and a single instruction string.

        Args:
            im: preprocessed image tensor passed to ``encode_image``.
            lang: instruction text; it is tokenised as a one-element batch.

        Returns:
            ``(image_embedding, text_embedding)`` as produced by the model.
        """
        e = self.m.encode_image(im)
        # Put the tokens on the same device as the image batch. The previous
        # hard-coded 'cuda' failed whenever the encoder was run elsewhere.
        tokens = clip.tokenize([lang]).to(im.device)
        l = self.m.encode_text(tokens)
        return e, l

class DecisionnceEnc(nn.Module):
    """Adapter giving a DecisionNCE model the ``(image, instruction)`` call shape.

    Args:
        m: model providing ``encode_image_eval`` and ``encode_text``.
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, im, lang):
        """Return ``(image_embedding, text_embedding)``; the image is encoded first."""
        model = self.m
        return model.encode_image_eval(im), model.encode_text(lang)
    
class LivEnc(nn.Module):
    """Adapter giving a LIV model the ``(image, instruction)`` call shape.

    Args:
        m: model called as ``m(input=..., modality="vision" | "text")``.
    """

    def __init__(self, m):
        super().__init__()
        self.m = m
        # Kept as a public attribute; forward() itself does not use it.
        self.transform = T.Compose([T.ToTensor()])

    def forward(self, im, lang):
        """Encode the image batch and one instruction string.

        Returns:
            ``(image_embedding, text_embedding)`` where the text embedding is
            the first row of the model's output for the one-element token batch.
        """
        image_features = self.m(input=im, modality="vision")
        tokens = clip.tokenize([lang])
        text_features = self.m(input=tokens, modality="text")[0]
        return image_features, text_features
    
class AcTOLEnc(nn.Module):
    """Adapter giving an AcTOL model the ``(image, instruction)`` call shape.

    Args:
        m: model providing ``encode_image_eval`` and ``encode_text``.
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, im, lang):
        """Return ``(image_embedding, text_embedding)``; the image is encoded first."""
        image_features = self.m.encode_image_eval(im)
        text_features = self.m.encode_text(lang)
        return image_features, text_features
    
class r3mEnc(nn.Module):
    """Pairs a visual encoder with a separate language encoder.

    Args:
        m: callable applied to the image batch.
        l: callable applied to the instruction.
    """

    def __init__(self, m, l):
        super().__init__()
        self.m = m
        self.l = l

    def forward(self, im, lang):
        """Return ``(m(im), l(lang))``; the image is encoded first."""
        return self.m(im), self.l(lang)
    
class StateEmbedding(gym.ObservationWrapper):
    """
    Observation wrapper that replaces image observations with encoder features.

    Every frame is turned into a PIL image, preprocessed to 3 x 224 x 224 and
    passed through the selected encoder. When ``lang_cond`` is set, the
    embedding of the instruction ``lang`` is appended; when ``proprio`` is
    non-zero, the first ``proprio`` entries of the unwrapped environment's
    state observation are appended last.

    Args:
        env (Gym environment): the original environment; frames are expected
            in the [0, 255] range,
        embedding_name (str): torchvision ResNet name ('resnet18', 'resnet34',
            'resnet50'); only used when ``load_path`` is "" or "random",
        device (str, 'cuda'): the GPU is used only when this is exactly 'cuda'
            and CUDA is available, otherwise the encoder runs on the CPU,
        load_path (str): encoder selector, one of "" (pretrained ResNet),
            "random" (un-pretrained ResNet), "clip", "clipvit", "r3m", "dt",
            "dtvit", "dp", "actol", "liv",
        proprio (int): number of state entries appended to the embedding,
        camera_name, env_name: accepted for interface compatibility; unused,
        lang_cond (bool): append the instruction embedding to observations,
        lang (str): the instruction handed to the language encoder.

    Raises:
        NameError: if ``load_path`` is not one of the selectors above.
        ValueError: if ``lang_cond`` is requested for a ResNet encoder, which
            has no language branch.
    """

    # load_path -> (CLIP architecture, embedding width) for plain CLIP encoders.
    _CLIP_VARIANTS = {
        "clip": ("RN50", 1024),
        "clipvit": ("ViT-B/32", 512),
    }
    # load_path -> (CLIP architecture supplying the preprocessing,
    #               DecisionNCE checkpoint name, embedding width).
    _DECISIONNCE_VARIANTS = {
        "dt": ("RN50", "DecisionNCE-T", 1024),
        "dtvit": ("ViT-B/32", "DecisionNCE-T-ViT", 512),
        "dp": ("RN50", "DecisionNCE-P", 1024),
    }
    # Selectors that load a vision-only torchvision ResNet.
    _RESNET_PATHS = ("random", "")

    def __init__(self, env, embedding_name=None, device='cuda', load_path="", proprio=0, camera_name=None, env_name=None, lang_cond=False, lang=None):
        gym.ObservationWrapper.__init__(self, env)
        self.lang_cond = lang_cond
        self.lang = lang
        self.proprio = proprio
        self.load_path = load_path
        self.start_finetune = False

        if lang_cond and load_path in self._RESNET_PATHS:
            raise ValueError(
                "lang_cond=True requires an encoder with a language branch; "
                "load_path=%r selects a vision-only ResNet" % (load_path,))

        # Resolve the target device BEFORE loading so the loaders put (and
        # cast) the weights where the model will actually run, instead of
        # always loading onto 'cuda'.
        use_cuda = device == 'cuda' and torch.cuda.is_available()
        load_device = 'cuda' if use_cuda else 'cpu'

        embedding, embedding_dim, lang_dim, self.transforms = self._build_embedding(
            load_path, embedding_name, device, load_device)
        embedding.eval()

        if use_cuda:
            print('Using CUDA.')
        else:
            print('Not using CUDA.')
        self.device = torch.device(load_device)
        embedding.to(device=self.device)

        self.embedding, self.lang_dim, self.embedding_dim = embedding, lang_dim, embedding_dim
        if self.lang_cond:
            self.observation_space = Box(
                    low=-np.inf, high=np.inf, shape=(self.embedding_dim+self.lang_dim+self.proprio,))
        else:
            self.observation_space = Box(
                        low=-np.inf, high=np.inf, shape=(self.embedding_dim+self.proprio,))

    def _build_embedding(self, load_path, embedding_name, device, load_device):
        """Load the encoder selected by ``load_path``.

        Returns:
            ``(embedding, embedding_dim, lang_dim, transforms)``; ``lang_dim``
            is 0 for encoders without a language branch.
        """
        if load_path in self._RESNET_PATHS:
            embedding, embedding_dim = _get_embedding(embedding_name=embedding_name, load_path=load_path)
            transforms = T.Compose([T.Resize(256),
                                    T.CenterCrop(224),
                                    T.ToTensor(),  # ToTensor() divides by 255
                                    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
            return embedding, embedding_dim, 0, transforms

        if load_path == "r3m":
            from r3m import load_r3m_reproduce
            from r3m.models.models_language import LangEncoder
            # R3M receives the caller's original ``device`` argument, as before.
            lang_enc = LangEncoder(device, 0, 0)
            rep = load_r3m_reproduce("r3m", device)
            transforms = T.Compose([T.Resize(256),
                                    T.CenterCrop(224),
                                    T.ToTensor()])  # ToTensor() divides by 255
            return r3mEnc(rep, lang_enc), rep.module.outdim, 768, transforms

        if load_path in self._CLIP_VARIANTS:
            import clip
            arch, dim = self._CLIP_VARIANTS[load_path]
            model, transforms = clip.load(arch, device=load_device)
            return ClipEnc(model), dim, dim, transforms

        if load_path in self._DECISIONNCE_VARIANTS:
            import DecisionNCE
            import clip
            arch, checkpoint, dim = self._DECISIONNCE_VARIANTS[load_path]
            # Only the CLIP preprocessing pipeline is reused here.
            _, transforms = clip.load(arch, device=load_device)
            model = DecisionNCE.load(checkpoint, device=load_device)
            return DecisionnceEnc(model), dim, dim, transforms

        if load_path == "actol":
            import AcTOL
            import clip
            _, transforms = clip.load("RN50", device=load_device)
            model = AcTOL.load("actol", device=load_device)
            return AcTOLEnc(model), 1024, 1024, transforms

        if load_path == "liv":
            from liv import load_liv
            import clip
            _, transforms = clip.load("RN50", device=load_device)
            model = load_liv()
            return LivEnc(model), 1024, 1024, transforms

        raise NameError("Invalid Model")

    def _preprocess(self, frame):
        """Turn one [0, 255] frame into a ``(1, 3, 224, 224)`` tensor on the CPU."""
        tensor = self.transforms(Image.fromarray(frame.astype(np.uint8))).reshape(-1, 3, 224, 224)
        if self.load_path == "r3m":
            ## R3M Expects input to be 0-255, preprocess makes 0-1
            tensor *= 255.0
        return tensor

    def _embed(self, inp):
        """Run the encoder; returns ``(img_emb, lang_emb)``.

        ``lang_emb`` is None for vision-only encoders (``lang_dim == 0``),
        which take the image batch as their only argument.
        """
        if self.lang_dim:
            return self.embedding(inp, self.lang)
        return self.embedding(inp), None

    def observation(self, observation):
        """Embed a single frame.

        Returns:
            1-D numpy array laid out as image embedding, then the instruction
            embedding (if ``lang_cond``), then proprioception (if ``proprio``).
        """
        ### INPUT SHOULD BE [0,255]
        if self.embedding is None:
            return observation

        inp = self._preprocess(observation).to(self.device)
        with torch.no_grad():
            img_emb, lang_emb = self._embed(inp)
            parts = [img_emb.view(-1, self.embedding_dim).to('cpu').numpy().squeeze()]
            if self.lang_cond:
                parts.append(lang_emb.view(-1, self.lang_dim).to('cpu').numpy().squeeze())

        ## IF proprioception add it to end of embedding
        if self.proprio:
            try:
                proprio = self.env.unwrapped.get_obs()[:self.proprio]
            except Exception:
                # Best-effort fallback for environments that only expose the
                # private accessor; KeyboardInterrupt/SystemExit still propagate.
                proprio = self.env.unwrapped._get_obs()[:self.proprio]
            parts.append(proprio)

        # Previously nothing was returned (unbound local) when proprio == 0.
        return np.concatenate(parts)

    def encode_batch(self, obs, finetune=False):
        """Embed a batch of frames.

        Args:
            obs: iterable of [0, 255] frames.
            finetune: when True and ``start_finetuning()`` has been called,
                the encoder runs with gradients enabled and a tensor on
                ``self.device`` is returned instead of a numpy array.

        Returns:
            Embeddings with one row per frame, with the instruction embedding
            appended to every row when ``lang_cond`` is set. On the numpy path
            the result is squeezed, so a single frame yields a 1-D array.
        """
        ### INPUT SHOULD BE [0,255]
        inp = torch.cat([self._preprocess(o) for o in obs]).to(self.device)

        if finetune and self.start_finetune:
            img_emb, lang_emb = self._embed(inp)
            emb = img_emb.view(-1, self.embedding_dim)
            if self.lang_cond:
                # Assumes one instruction row, broadcast across the batch.
                lang_rows = lang_emb.view(-1, self.lang_dim).expand(emb.shape[0], -1)
                emb = torch.cat([emb, lang_rows], dim=1)
            return emb

        with torch.no_grad():
            img_emb, lang_emb = self._embed(inp)
            emb = img_emb.view(-1, self.embedding_dim).to('cpu').numpy()
            if self.lang_cond:
                lang_rows = lang_emb.view(-1, self.lang_dim).to('cpu').numpy()
                # repeat lang_emb to match img_emb
                lang_rows = np.repeat(lang_rows, emb.shape[0], axis=0)
                emb = np.concatenate([emb, lang_rows], axis=1)
        # Squeeze only after concatenation: squeezing first collapsed the batch
        # axis for a single frame and made the axis=1 concatenate fail.
        return emb.squeeze()

    def get_obs(self):
        """Return the embedded current observation of the wrapped environment."""
        if self.embedding is not None:
            return self.observation(self.env.observation(None))
        else:
            # returns the state based observations
            return self.env.unwrapped.get_obs()

    def start_finetuning(self):
        """Enable the gradient-carrying path of ``encode_batch(finetune=True)``."""
        self.start_finetune = True


class MuJoCoPixelObs(gym.ObservationWrapper):
    """Observation wrapper that returns a rendered camera image instead of state.

    Args:
        env: environment to wrap; ``env.spec.id`` is inspected, and rendering
            goes through ``self.sim`` (not set here; presumably resolved on
            the wrapped environment).
        width, height: render size passed to ``sim.render``.
        camera_name: camera to render from; "default" is not supported.
        device_id: passed through to ``sim.render``.
        depth: passed through to ``sim.render``.
            NOTE(review): ``get_image`` indexes the render result as a single
            array; confirm ``depth=True`` is actually supported.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, env, width, height, camera_name, device_id=-1, depth=False, *args, **kwargs):
        gym.ObservationWrapper.__init__(self, env)
        # NOTE(review): declared channel-first, while get_image() returns the
        # renderer's array unchanged apart from the flip; confirm the layout.
        self.observation_space = Box(low=0., high=255., shape=(3, width, height))
        self.width = width
        self.height = height
        self.camera_name = camera_name
        self.depth = depth
        self.device_id = device_id
        if "v2" in env.spec.id:
            # Environments with "v2" in their id only expose the private accessor.
            self.get_obs = env._get_obs

    def get_image(self):
        """Render the configured camera and return the image flipped along axis 0.

        Raises:
            AssertionError: if ``camera_name`` is "default".
        """
        if self.camera_name == "default":
            # Raised explicitly (same exception type as the old ``assert``) so
            # the guard still holds under ``python -O``, which strips asserts.
            raise AssertionError("Camera not supported")
        img = self.sim.render(width=self.width, height=self.height, depth=self.depth,
                              camera_name=self.camera_name, device_id=self.device_id)
        return img[::-1, :, :]

    def observation(self, observation):
        # This function creates observations based on the current state of the environment.
        # Argument `observation` is ignored, but `gym.ObservationWrapper` requires it.
        return self.get_image()
        